
(* this is an incomplete implementation of a DNS resolver *)
module Dnsresolver = 

(* A query datagram waiting to be sent: raw DNS payload and its destination. *)
type packet = 
    {
      packetdata: string;
      destaddr: Posix.sockaddr;
    }

(* State of a resolve transaction:
   Pending          - initial state; no query has been sent yet
   RequestSent      - write_handler has sent the current query
   NameServerFailed - the current nameserver did not answer within tryInterval;
                      manageTransaction moves on to the next one
   Failed           - no nameserver left, or a send/recv error *)
type transactionstate = Pending | RequestSent | NameServerFailed | Failed
(* DNS record types this resolver can request and decode. *)
type dnstype = A | CNAME | PTR
type dnstransaction =
    {
      socketorig: int;                 (* UDP socket dedicated to this transaction *)
      hostname: string;                (* name to resolve; for PTR, a dotted-quad IPv4 address *)
      typerequest: dnstype;
      nameservers: list string;        (* nameserver addresses, tried in order *)
      mutable resolvestate: transactionstate;
      mutable currentnameserver: int;  (* index into nameservers *)
      readqueue: list string;          (* currently unused *)
      mutable writequeue: list packet; (* queries waiting for write_handler; updated by doRequest and write_handler *)
      mutable timerlasttry: int64;     (* monotonic seconds of the last attempt / failure decision *)
      timecreated: int64;              (* monotonic seconds at creation, for transactionTimeout *)
    }
(* One decoded answer record, as stored in the cache. *)
type dnsansweritem =
    {
      name: string;                    (* owner name of the record *)
      typeanswer: dnstype;
      ipaddress: string;               (* A: dotted-quad address; CNAME/PTR: target name *)
      realname: string; (* only for CNAME *)
      timeadded: int64;                (* monotonic seconds when decoded; drives cache expiry *)
    }

(* Timing parameters, all in seconds of the monotonic clock. *)
(* Lifetime of a cached answer; the dns resolver has a cache of 1h, regardless of everything else. *)
let cachetimeout = 3600L
(* Age after which manageTransaction closes a transaction, answered or not. *)
let transactionTimeout = 20L
(* Time a sent query may stay unanswered before its nameserver is considered failed. *)
let tryInterval = 2L
(* doTasks only does work when more than this many seconds passed since its last run. *)
let timerDoTasksInterval = 1L

(* Mutable resolver state. *)
(* Time of the last doTasks run. *)
let timerDoTasks = ref 0L
(* Transactions that have not been closed yet. *)
let transactions = ref []
(* Cache of decoded answers. *)
let cachedresults = ref []
(* Receive buffer used by read_handler; 66000 bytes is larger than any UDP datagram. *)
let readbuf = String.build 66000

(* Print [msg] on its own line, prefixed with the monotonic clock reading
   formatted as [seconds.nanoseconds]. *)
let log msg =
  let (seconds, nanoseconds) = Posix.clock_gettime_monotonic () in
  let line = sprintf "[%Li.%09i] %s\n" seconds nanoseconds msg in
  printf "%s" line

(* Reverse-lookup name of a dotted-quad IPv4 address: its components in
   reverse order followed by "in-addr.arpa", e.g.
   "192.0.2.10" -> "10.2.0.192.in-addr.arpa". This is the name a PTR
   query asks about. *)
let getnormalizedip ip =
  case List.rev (String.split ip '.')
  | [] -> "in-addr.arpa"
  | last :: earlier ->
    let reversed = List.fold (fun acc octet -> acc ^ "." ^ octet) last earlier in
    reversed ^ ".in-addr.arpa"
  end

(* True when [s] is a dotted-quad IPv4 literal such as "192.168.0.1":
   exactly four dot-separated components, each made of 1 to 3 decimal
   digits, with a value below 256.
   Components are checked for digits before String.to_int is applied.
   Without that check, "-1.2.3.4" was accepted, and ordinary four-label
   hostnames ("www.example.co.uk") or empty components were handed to the
   integer conversion. *)
let isipaddressv4 s =
  let isdigitstring part =
    let n = String.length part in
    let rec alldigits i =
      if i >= n then true
      else
        let c = int_of_char part.[i] in
        if c >= 48 && c <= 57 then alldigits (i + 1) else false in
    n >= 1 && n <= 3 && alldigits 0 in
  let parts = String.split s '.' in
  List.length parts = 4 && List.forallof (fun part -> isdigitstring part && String.to_int part < 256) parts

(* Look [name] up without touching the network.
   - A dotted-quad IPv4 literal is returned unchanged.
   - Otherwise, if the cache holds A or PTR records for [name], one of them
     is picked at random and its ipaddress field is returned. For PTR
     records that field holds the target host name.
   - Otherwise, if [name] has a CNAME record, the first one is followed
     once, and an A/PTR record of its target is picked the same way.
   Returns None when nothing usable is cached. The records' timeadded is
   not consulted here. *)
let getipfromname name =
  let isaddress r = r.typeanswer = A || r.typeanswer = PTR in
  (* uniformly random element of [candidates], as its ipaddress *)
  let pickone candidates =
    case candidates
    | [] -> None
    | _ ->
      let idx = Random.int_below (List.length candidates) in
      let chosen = List.nth idx candidates in
      Some chosen.ipaddress
    end in
  if isipaddressv4 name then
    Some name
  else
    let matching = List.filter (fun r -> r.name = name) !cachedresults in
    let direct = List.filter isaddress matching in
    if direct <> [] then
      pickone direct
    else
      case List.filter (fun r -> r.typeanswer = CNAME) matching
      | [] -> None
      | alias :: _ ->
        pickone (List.filter (fun r -> r.name = alias.realname && isaddress r) !cachedresults)
      end

(* Close the socket of [trans] (its result is ignored) and remove every
   transaction structurally equal to [trans] from [transactions]. *)
let closeTransaction trans =
  let _ = Posix.close trans.socketorig in
  let stillopen other = other <> trans in
  transactions := List.filter stillopen !transactions


(* Build a DNS query (RFC 1035 section 4.1) with one question of type
   [typereq].
   - A and CNAME queries ask about [hostname] itself.
   - PTR queries ask about the reverse-lookup name of the IPv4 address
     [hostname] (see getnormalizedip).
   The header carries the fixed transaction id 0x0001 and only the RD
   (recursion desired) flag.
   NOTE(review): labels longer than 63 bytes are not rejected; they produce
   an invalid length byte. *)
let createRequest hostname typereq =
  let transid = "\x00\x01" in
  let flags = "\x01\x00" in (* RD bit set: ask the server to recurse *)
  let questions = "\x00\x01" in
  let answerRRs = "\x00\x00" in
  let authorityRRs = "\x00\x00" in
  let additionalRRs = "\x00\x00" in
  (* QTYPE codes (RFC 1035 section 3.2.2). CNAME needs its own code: an
     empty QTYPE makes the packet malformed. *)
  let typereqstring = case typereq | A -> "\x00\x01" | CNAME -> "\x00\x05" | PTR -> "\x00\x0C" end in
  let classreq = "\x00\x01" in (* class IN *)
  let qname = case typereq | A -> hostname | CNAME -> hostname | PTR -> getnormalizedip hostname end in
  (* Drop empty labels (e.g. from a trailing dot in "example.com."): a zero
     length byte would terminate QNAME early and corrupt the question. *)
  let labels = List.filter (fun label -> label <> "") (String.split qname '.') in
  let encodedname = List.fold (fun acc label -> acc ^ (sprintf "%c" (char_of_int (String.length label))) ^ label) "" labels in
    transid ^ flags ^ questions ^ answerRRs ^ authorityRRs ^ additionalRRs ^
    encodedname ^ "\x00" ^ typereqstring ^ classreq

(* Decode a raw DNS response [s] (RFC 1035 section 4.1). Returns the A,
   CNAME and PTR records of its answer section, in reverse order of
   appearance. Each record is stamped with the current monotonic time in
   seconds.
   - A:     ipaddress is the dotted-quad address.
   - CNAME: ipaddress and realname both hold the target name.
   - PTR:   ipaddress holds the host name the reverse lookup points to.
   The packet comes from the network and is untrusted, so:
   - every read is bounds checked;
   - compression pointers are followed at most 64 times per name, so
     pointer cycles cannot hang the resolver;
   - only the ANCOUNT records after the QDCOUNT questions are decoded, so
     authority/additional records (e.g. unsolicited glue A records) are
     never returned for caching.
   Decoding stops at the first malformed record; records decoded before it
   are kept. *)
let decodednsanswer s =
  let len = String.length s in
  let getbyte p = int_of_char s.[p] in
  let getu16 p = (getbyte p) * 256 + getbyte (p + 1) in
  let joinlabels labels =
    case labels
    | h :: t -> List.fold (fun b a -> b ^ "." ^ a) h t
    | [] -> ""
    end in
  (* Labels of the (possibly compressed) name at offset [p]. Returns None
     when the name is truncated, uses a reserved label type, or follows
     more than 64 compression pointers. *)
  let rec parsename p hops acc =
    if p >= len || hops > 64 then None
    else
      let c = getbyte p in
      if c = 0 then Some (List.rev acc)
      else if c land 0xC0 = 0xC0 then
        (if p + 1 >= len then None
         else parsename ((c land 0x3F) * 256 + getbyte (p + 1)) (hops + 1) acc)
      else if c land 0xC0 <> 0 then None
      else if p + 1 + c > len then None
      else parsename (p + 1 + c) hops (String.sub s (p + 1) c :: acc) in
  (* Offset just past the name at [p], or None if it is malformed. A
     compression pointer always ends a name, so it is not followed. *)
  let rec skipname p =
    if p >= len then None
    else
      let c = getbyte p in
      if c = 0 then Some (p + 1)
      else if c land 0xC0 = 0xC0 then Some (p + 2)
      else if c land 0xC0 <> 0 then None
      else skipname (p + 1 + c) in
  let now = fst (Posix.clock_gettime_monotonic ()) in
  (* Prepend to [acc] the record of type [rtype] owned by [owner] whose
     RDATA spans [rdata, rdata + rdlength), if it is an A, CNAME or PTR. *)
  let addrecord owner rtype rdata rdlength acc =
    if rtype = 1 && rdlength = 4 then
      let ip = sprintf "%i.%i.%i.%i" (getbyte rdata) (getbyte (rdata + 1)) (getbyte (rdata + 2)) (getbyte (rdata + 3)) in
      { name = owner; typeanswer = A; ipaddress = ip; realname = ""; timeadded = now; } :: acc
    else if rtype = 5 || rtype = 12 then
      case parsename rdata 0 []
      | None -> acc
      | Some labels ->
        let target = joinlabels labels in
        if labels = [] then acc
        else if rtype = 5 then
          { name = owner; typeanswer = CNAME; ipaddress = target; realname = target; timeadded = now; } :: acc
        else
          { name = owner; typeanswer = PTR; ipaddress = target; realname = ""; timeadded = now; } :: acc
      end
    else
      acc in
  (* Skip [n] question entries starting at [p]: a name, then QTYPE and QCLASS (4 bytes). *)
  let rec skipquestions p n =
    if n = 0 then Some p
    else
      case skipname p
      | Some q -> skipquestions (q + 4) (n - 1)
      | None -> None
      end in
  (* Decode up to [n] resource records starting at [p], accumulating into [acc]. *)
  let rec readanswers p n acc =
    if n = 0 then acc
    else
      case parsename p 0 []
      | None -> acc
      | Some ownerlabels ->
        case skipname p
        | None -> acc
        | Some q ->
          (* fixed part after the owner name: TYPE(2) CLASS(2) TTL(4) RDLENGTH(2) *)
          if q + 10 > len then acc
          else
            let rdata = q + 10 in
            let rdlength = getu16 (q + 8) in
            if rdata + rdlength > len then acc
            else
              readanswers (rdata + rdlength) (n - 1) (addrecord (joinlabels ownerlabels) (getu16 q) rdata rdlength acc)
        end
      end in
  (* 12-byte header: QDCOUNT at offset 4, ANCOUNT at offset 6 *)
  if len < 12 then []
  else
    case skipquestions 12 (getu16 4)
    | Some p -> readanswers p (getu16 6) []
    | None -> []
    end

(* Queue a query for [trans] to its current nameserver (the entry at index
   currentnameserver of nameservers) on UDP port 53. write_handler sends it
   once the socket is writable.
   Queuing puts the transaction back in Pending and restarts its retry
   clock. This way manageTransaction does not skip to yet another
   nameserver while the query is still waiting in the queue, and tryInterval
   is measured from this attempt.
   When every nameserver has been tried, the transaction is marked Failed.
   NOTE(review): how an unparsable nameserver address is reported depends on
   Posix.in_addr_of_string (raise vs. sentinel); it is not handled here. *)
let doRequest trans =
  if trans.currentnameserver < List.length trans.nameservers then
    let ipaddr = Posix.in_addr_of_string (List.nth trans.currentnameserver trans.nameservers) in
    let destaddr = Posix.ADDR_INET(ipaddr, 53) in
    let data = createRequest trans.hostname trans.typerequest in
    let packet = { packetdata = data; destaddr = destaddr; } in
    let () = trans.writequeue <- List.append trans.writequeue [packet] in
    let () = trans.resolvestate <- Pending in
    trans.timerlasttry <- fst (Posix.clock_gettime_monotonic ())
  else
    let () = printf "no more nameservers to ask\n" in 
    trans.resolvestate <- Failed

(* Start resolving [hostname] with query type [typetrans], trying the
   addresses in [nameservers] in order. Opens a dedicated UDP socket,
   registers the new transaction in [transactions] and queues the first
   query (see doRequest). If the socket cannot be created, nothing is
   registered.
   TODO: it would be better to take the first 5 nameservers, shuffle them,
   and start from there. *)
let createResolveTransaction hostname typetrans nameservers = 
  let () = log (sprintf "Dnsresolver: creating resolve transaction for hostname %s" hostname) in 
  let sock = Posix.socket Posix.AF_INET Posix.SOCK_DGRAM 0 in 
  if sock >= 0 then
    let trans =
      { socketorig = sock;
        hostname = hostname;
        typerequest = typetrans;
        nameservers = nameservers;
        resolvestate = Pending;
        currentnameserver = 0;
        readqueue = [];
        writequeue = [];
        timerlasttry = fst (Posix.clock_gettime_monotonic ());
        timecreated = fst (Posix.clock_gettime_monotonic ()); } in
    let () = transactions := trans :: !transactions in
    doRequest trans
  else
    ()

(* Handle writability of resolver socket [sock]: pop the next queued packet
   of the transaction owning [sock], connect the socket to its destination
   and send it. The transaction becomes RequestSent if bytes were sent, and
   Failed otherwise.
   With an empty queue (spurious writability) nothing is sent. Sending a
   placeholder would return 0 and wrongly fail the transaction. *)
let write_handler sock = 
  case List.filter (fun trans -> trans.socketorig = sock) !transactions
  | [] -> ()
  | mytrans :: _ ->
    case mytrans.writequeue
    | [] -> ()
    | packet :: rest ->
      let () = mytrans.writequeue <- rest in
      let _ = Posix.connect mytrans.socketorig packet.destaddr in 
      let len = Posix.send mytrans.socketorig packet.packetdata 0 (String.length packet.packetdata) in 
      if len > 0 then
        mytrans.resolvestate <- RequestSent 
      else
        mytrans.resolvestate <- Failed
    end
  end

(* Handle readability of resolver socket [sock].
   Receive one datagram, decode it, print its records, store them in
   [cachedresults], and close every transaction the answer satisfies.
   A recv returning no data marks the transaction owning [sock] as Failed. *)
let read_handler sock = 
  let () = log "in Dnsresolver.read_handler" in 
  let len = Posix.recv sock readbuf 0 66000 in
  if len > 0 then 
      let payload = String.sub readbuf 0 len in
      let replies = decodednsanswer payload in 
      let () = List.iter (fun reply -> printf "type %s %s %s\n" (case reply.typeanswer | A -> "A" | CNAME -> "CNAME" | PTR -> "PTR" end) 
                                         reply.name reply.ipaddress) replies in 
      (* Records are the same cache entry when they agree on everything but
         their timestamp. Whole-record equality never matches across
         responses, because timeadded differs. Replace the older copy, so the
         entry is not duplicated and its age restarts from this answer. *)
      let samerecord a b = a.name = b.name && a.typeanswer = b.typeanswer && a.ipaddress = b.ipaddress && a.realname = b.realname in
      let () = List.iter (fun reply -> cachedresults := reply :: List.filter (fun old -> not (samerecord old reply)) !cachedresults) replies in
      (* The name the transaction actually asked about. PTR queries are sent
         for the reverse-lookup name, so their answers are owned by
         "d.c.b.a.in-addr.arpa", not by the dotted-quad hostname. *)
      let queriedname trans = case trans.typerequest | PTR -> getnormalizedip trans.hostname | A -> trans.hostname | CNAME -> trans.hostname end in
      let answered trans = List.exists (fun reply -> reply.name = queriedname trans) replies in
      List.iter closeTransaction (List.filter answered !transactions)
  else
     let mytranslist = List.filter (fun trans -> trans.socketorig = sock) !transactions in
     case mytranslist 
     | h :: _ -> h.resolvestate <- Failed 
     | [] -> ()
     end

(* One housekeeping step for [trans].
   - Older than transactionTimeout seconds: close it.
   - RequestSent with no answer for more than tryInterval seconds: mark
     NameServerFailed and record the time.
   - NameServerFailed: advance to the next nameserver and query it
     (doRequest).
   - Any other state: leave it alone. *)
let manageTransaction trans = 
  let (now, _) = Posix.clock_gettime_monotonic () in
  let age = Int64.sub now trans.timecreated in
  let sincelasttry = Int64.sub now trans.timerlasttry in
  if Int64.compare age transactionTimeout > 0 then
    closeTransaction trans
  else
    case trans.resolvestate
    | RequestSent ->
      if Int64.compare sincelasttry tryInterval > 0 then
        let () = trans.resolvestate <- NameServerFailed in
        trans.timerlasttry <- now
      else
        ()
    | NameServerFailed ->
      let () = trans.currentnameserver <- trans.currentnameserver + 1 in
      doRequest trans
    | Pending -> ()
    | Failed -> ()
    end

(* True when [sock] belongs to one of the open transactions. *)
let isSocketHere sock =
  let owns trans = trans.socketorig = sock in
  List.exists owns !transactions

(* Socket and poll events of every open transaction. A socket is always
   polled for reading, and for writing only while it has queued packets. *)
let getAllSockets () = 
  let eventsfor trans =
    case trans.writequeue
    | [] -> [| Posix.POLLIN |]
    | _ -> [| Posix.POLLIN; Posix.POLLOUT |]
    end in
  List.map (fun trans -> (trans.socketorig, eventsfor trans)) !transactions
  
(* Periodic housekeeping. It does work only when more than
   timerDoTasksInterval seconds have passed since its last run. Then it:
   - runs manageTransaction on every transaction (retries, nameserver
     fail-over, timeouts);
   - evicts cache entries older than cachetimeout seconds. *)
let doTasks () = 
  let () = log "in dnsresolver doTasks" in 
  let (now_sec, _) = Posix.clock_gettime_monotonic () in 
  if Int64.compare (Int64.sub now_sec !timerDoTasks) timerDoTasksInterval > 0 then
    let () = timerDoTasks := now_sec in 
    let () = List.iter manageTransaction !transactions in 
    (* Keep entries whose age is at most cachetimeout. Testing "> 0" here
       would drop every fresh entry and keep only the stale ones. *)
    let isfresh r = Int64.compare (Int64.sub now_sec r.timeadded) cachetimeout <= 0 in
    cachedresults := List.filter isfresh !cachedresults
  else
    ()


endmodule
